{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install allennlp\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.data.iterators import BasicIterator\n",
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.training.trainer import Trainer\n",
    "from torch.nn import CosineSimilarity\n",
    "from torch.nn import functional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from examples.embeddings.word2vec import SkipGramReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_DIM = 256\n",
    "BATCH_SIZE = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SkipGramModel(Model):\n",
    "    def __init__(self, vocab, embedding_in):\n",
    "        super().__init__(vocab)\n",
    "        self.embedding_in = embedding_in\n",
    "        self.linear = torch.nn.Linear(\n",
    "            in_features=EMBEDDING_DIM,\n",
    "            out_features=vocab.get_vocab_size('token_out'),\n",
    "            bias=False)\n",
    "\n",
    "    def forward(self, token_in, token_out):\n",
    "        embedded_in = self.embedding_in(token_in)\n",
    "        logits = self.linear(embedded_in)\n",
    "        loss = functional.cross_entropy(logits, token_out)\n",
    "\n",
    "        return {'loss': loss}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_related(token: str, embedding: Model, vocab: Vocabulary, num_synonyms: int = 10):\n",
    "    \"\"\"Given a token, return a list of top N most similar words to the token.\"\"\"\n",
    "    token_id = vocab.get_token_index(token, 'token_in')\n",
    "    token_vec = embedding.weight[token_id]\n",
    "    cosine = CosineSimilarity(dim=0)\n",
    "    sims = Counter()\n",
    "\n",
    "    for index, token in vocab.get_index_to_token_vocabulary('token_in').items():\n",
    "        sim = cosine(token_vec, embedding.weight[index]).item()\n",
    "        sims[token] = sim\n",
    "\n",
    "    return sims.most_common(num_synonyms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reader = SkipGramReader()\n",
    "text8 = reader.read('https://realworldnlpbook.s3.amazonaws.com/data/text8/text8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(text8)\n",
    "text8 = text8[:1000000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = Vocabulary.from_instances(\n",
    "    text8, min_count={'token_in': 5, 'token_out': 5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterator = BasicIterator(batch_size=BATCH_SIZE)\n",
    "iterator.index_with(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_in = Embedding(num_embeddings=vocab.get_vocab_size('token_in'),\n",
    "                         embedding_dim=EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SkipGramModel(vocab=vocab,\n",
    "                      embedding_in=embedding_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model=model,\n",
    "                  optimizer=optimizer,\n",
    "                  iterator=iterator,\n",
    "                  train_dataset=text8,\n",
    "                  num_epochs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_related('one', embedding_in, vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(get_related('december', embedding_in, vocab))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
